This file displays RSA regression weights for the Schaefer (2018) atlas, 200-parcel, 7-network version.
## libraries ----
library(here)
library(data.table)
source(here("src", "stroop-rsa-pc.R"))
library(knitr)
library(dplyr)
library(tidyr)
library(ggplot2)
library(abind)
library(reticulate)
library(mfutils)
## settings, params ----
theme_set(theme_bw(base_size = 10))
opts_chunk$set(echo = TRUE)
# knit_engines$set(python = eng_python)
## Analysis options: taken from the knitr `params` when they exist (e.g. when
## rendering with YAML params); otherwise fall back to these defaults.
if (!exists("params")) {
  ttype_subset <- "bias"
  prewh <- "none"
  measure <- "crcor"
} else {
  ttype_subset <- params$ttype_subset
  prewh <- params$prewh
  measure <- params$measure
}
## constants ----
wd <- here()
glmname <- "lsall_1rpm"
roiset <- "Schaefer2018Parcel200"
sessions <- c("baseline", "proactive")
## Subject lists and session-wave labels are paired element-wise when the
## weight files are read below: mc1/mc2 -> baseline wave1/wave2,
## mi1/mi2 -> proactive wave1/wave2.
subjlists <- c("mc1", "mc2", "mi1", "mi2")
seswaves <- paste0(rep(sessions, each = 2), "_wave", c(1, 2))
## data ----
## read regression weights: one file per (subject list, session-wave) pair,
## stacked row-wise into a single data.table.
dat <- rbindlist(mapply(
  function(subjlist, suffix) {
    fread(construct_filename_weights(
      measure = measure, subjlist = subjlist, glmname = glmname,
      ttype_subset = ttype_subset, roiset = roiset, prewh = prewh,
      suffix = suffix
    ))
  },
  subjlists,
  paste0("__seswave-", seswaves),
  SIMPLIFY = FALSE
))
## calculate ----
## for group stats:
## For each session x term x wave x parcel (roi), summarize the weights `b`
## (presumably one row per subject -- confirm):
##   m      mean weight
##   t_stat one-sample t statistic against 0 (same for one- and two-sided tests)
##   p      one-sided p-value for mean(b) > 0
##   p_fdr  p FDR-adjusted across parcels within each session x term x wave map
## The result is ungrouped so downstream pipes/joins start from a clean slate.
dat_sum <- dat %>%
  group_by(session, term, wave, roi) %>%
  summarize(
    m = mean(b),
    t_stat = t.test(b)$statistic,
    p = t.test(b, alternative = "greater")$p.value,
    .groups = "drop"
  ) %>%
  group_by(session, term, wave) %>%
  mutate(p_fdr = p.adjust(p, method = "fdr")) %>%
  ungroup()
## bind atlas info:
## atlas key with a `roi` column made by deleting "7Networks_" from `id`,
## so it matches the roi labels in dat_sum
k <- mutate(
  schaefer2018_7_200_fsaverage5$key,
  roi = gsub("7Networks_", "", id, fixed = TRUE)
)
dat_sum <- dat_sum %>% left_join(k, by = "roi")
# Translate the R-side ROI-set name into the label/parcel count used by the
# Python plotting code, then run the shared setup script.
# NOTE(review): setup.py presumably defines np, plt, get_overlay and
# plot_surf_roi_montage used by the plotting chunks -- confirm.
if r.roiset == "Schaefer2018Parcel200":
    roiset = "200Parcels7Networks"
    nparc = 200
else:
    # Fail here rather than with a NameError on `nparc` in a later chunk.
    raise ValueError("unsupported roiset: " + repr(r.roiset))

_setup_path = r.wd + "/src/setup.py"
with open(_setup_path) as _setup_file:  # closes the handle (open().read() leaked it)
    exec(compile(_setup_file.read(), _setup_path, "exec"))
Number of subjects per wave and session:
with(distinct(dat, subject, wave, session), table(wave, session))
## session
## wave baseline proactive
## wave1 90 90
## wave2 40 40
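NB: these counts include twins, so subjects are not all independent observations (worth keeping in mind for the group-level t-tests).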
## Number of parcels per session x wave x model with FDR-corrected p < 0.05.
dat_sum %>%
  filter(term %in% c("distractor", "incongruency", "target")) %>%
  group_by(session, wave, term) %>%
  summarize(n_parcel = sum(p_fdr < 0.05), .groups = "drop_last") %>%
  ggplot(aes(
    x = interaction(session, wave, sep = " "), y = n_parcel, fill = term
  )) +
  geom_col(
    width = 0.5, position = position_dodge(width = 0.5), color = "black"
  ) +
  scale_fill_brewer(type = "qual", palette = 1) +
  labs(
    title = "threshold: p_fdr < 0.05",
    x = "session*wave",
    y = "number of parcels meeting threshold",
    fill = "model"
  ) +
  scale_y_continuous(limits = c(0, 200)) +
  theme(legend.position = c(0.7, 0.8))
## Within each session x wave, take the 20 rows with the HIGHEST t statistic
## (pooled over the three models), then count how many belong to each model.
## Fix: `order_by = -t_stat` selected the 20 *lowest* t statistics, the
## opposite of the plot title and of the cume_dist(t_stat) > 0.9 threshold
## used for the per-network plot.
## NOTE(review): 20 rows out of 3 models x nparc parcels is not "10%"; the
## title assumes 20/200 parcels -- confirm the intended threshold.
dat_sum %>%
  filter(term %in% c("distractor", "incongruency", "target")) %>%
  group_by(session, wave) %>%
  slice_max(order_by = t_stat, n = 20) %>%
  group_by(session, wave, term) %>%
  summarize(n_parcel = n(), .groups = "drop_last") %>%
  ggplot(aes(interaction(session, wave, sep = " "), n_parcel, fill = term)) +
  geom_col(
    width = 0.5, position = position_dodge(width = 0.5), color = "black"
  ) +
  scale_fill_brewer(type = "qual", palette = 1) +
  labs(
    y = "number of parcels meeting threshold", x = "session*wave",
    fill = "model", title = "threshold: parcels with top 10% of t_stats"
  ) +
  scale_y_continuous(limits = c(0, 20)) +
  theme(legend.position = c(0.7, 0.8))
## Number of parcels with FDR-corrected p < 0.05 per session x wave x model,
## split by atlas network (one facet row per network).
dat_sum %>%
  filter(term %in% c("distractor", "incongruency", "target")) %>%
  group_by(session, wave, term, network) %>%
  summarize(n_parcel = sum(p_fdr < 0.05), .groups = "drop_last") %>%
  ggplot(aes(
    x = interaction(session, wave, sep = " "), y = n_parcel, fill = term
  )) +
  geom_col(
    width = 0.5, position = position_dodge(width = 0.5), color = "black"
  ) +
  scale_fill_brewer(type = "qual", palette = 1) +
  labs(
    title = "threshold: p_fdr < 0.05",
    x = "session*wave",
    y = "number of parcels meeting threshold",
    fill = "model"
  ) +
  theme(legend.position = "none") +
  facet_grid(rows = vars(network))
## Flag the top 10% of parcels by t statistic within each session x wave x
## model, then count flagged parcels per atlas network.
dat_sum %>%
  filter(term %in% c("distractor", "incongruency", "target")) %>%
  group_by(session, wave, term) %>%
  mutate(is_sig = cume_dist(t_stat) > 0.9) %>%
  group_by(network, .add = TRUE) %>%
  summarize(n_parcel = sum(is_sig), .groups = "drop_last") %>%
  ggplot(aes(
    x = interaction(session, wave, sep = " "), y = n_parcel, fill = term
  )) +
  geom_col(
    width = 0.5, position = position_dodge(width = 0.5), color = "black"
  ) +
  scale_fill_brewer(type = "qual", palette = 1) +
  labs(
    title = "parcels with top 10% of t_stats",
    x = "session*wave",
    y = "number of parcels meeting threshold",
    fill = "model"
  ) +
  theme(legend.position = "none") +
  facet_grid(rows = vars(network))
# Surface maps of the group-mean weight `m` for every model term x wave x
# session. The colour range is fixed per term (over all sessions and waves).
# NOTE(review): `models` and `waves` are not defined in the visible R code;
# presumably they come from src/stroop-rsa-pc.R -- confirm.
dat_sum_py = r.dat_sum  # convert the R tibble to pandas once, not per access
crcor_terms = r.models["crcor"]
wave_labels = r.waves
session_labels = r.sessions
for term in crcor_terms:
    is_term = dat_sum_py.term == term
    # set color range per model (term); invariant across waves/sessions
    vmax = np.max(dat_sum_py.m[is_term])
    vmin = np.min(dat_sum_py.m[is_term])
    for wave in wave_labels:
        is_wave = dat_sum_py.wave == wave
        for session in session_labels:
            is_session = dat_sum_py.session == session
            d = dat_sum_py.loc[is_session & is_term & is_wave, :].copy()
            overlay_lh = get_overlay(d.idx, d.m, "left", nparc)
            overlay_rh = get_overlay(d.idx, d.m, "right", nparc)
            overlay = np.column_stack(np.stack((overlay_lh, overlay_rh)))
            fig = plot_surf_roi_montage(
                roi_map=overlay,
                title=session + " " + term + " " + wave,
                vmax=vmax,
                vmin=vmin,
                cbar_vmax=vmax,
                cbar_vmin=vmin,
                bg_on_data=True,
            )
            plt.show()
            fig.clear()
            plt.close("all")  # release each figure before drawing the next
# Surface maps of the group-mean weight `m`, with parcels that do not reach
# p_fdr < 0.05 set to 0. The colour range is fixed per term over all sessions,
# waves and parcels (including the sub-threshold ones).
# The threshold matches the R summaries (p_fdr < 0.05). The previous
# `p_fdr > 0.05` test kept p_fdr == 0.05 and NaN.
# NOTE(review): `models` and `waves` are not defined in the visible R code;
# presumably they come from src/stroop-rsa-pc.R -- confirm.
dat_sum_py = r.dat_sum  # convert the R tibble to pandas once, not per access
crcor_terms = r.models["crcor"]
wave_labels = r.waves
session_labels = r.sessions
for term in crcor_terms:
    is_term = dat_sum_py.term == term
    # set color range per model (term); invariant across waves/sessions
    vmax = np.max(dat_sum_py.m[is_term])
    vmin = np.min(dat_sum_py.m[is_term])
    for wave in wave_labels:
        is_wave = dat_sum_py.wave == wave
        for session in session_labels:
            is_session = dat_sum_py.session == session
            d = dat_sum_py.loc[is_session & is_term & is_wave, :].copy()
            d.loc[~(d.p_fdr < 0.05), "m"] = 0
            overlay_lh = get_overlay(d.idx, d.m, "left", nparc)
            overlay_rh = get_overlay(d.idx, d.m, "right", nparc)
            overlay = np.column_stack(np.stack((overlay_lh, overlay_rh)))
            fig = plot_surf_roi_montage(
                roi_map=overlay,
                title=session + " " + term + " " + wave,
                vmax=vmax,
                vmin=vmin,
                cbar_vmax=vmax,
                cbar_vmin=vmin,
                bg_on_data=True,
            )
            plt.show()
            fig.clear()
            plt.close("all")  # release each figure before drawing the next
# Surface maps of the group t statistic `t_stat` for every model term x wave x
# session. The colour range is fixed per term (over all sessions and waves).
# NOTE(review): `models` and `waves` are not defined in the visible R code;
# presumably they come from src/stroop-rsa-pc.R -- confirm.
dat_sum_py = r.dat_sum  # convert the R tibble to pandas once, not per access
crcor_terms = r.models["crcor"]
wave_labels = r.waves
session_labels = r.sessions
for term in crcor_terms:
    is_term = dat_sum_py.term == term
    # set color range per model (term); invariant across waves/sessions
    vmax = np.max(dat_sum_py.t_stat[is_term])
    vmin = np.min(dat_sum_py.t_stat[is_term])
    for wave in wave_labels:
        is_wave = dat_sum_py.wave == wave
        for session in session_labels:
            is_session = dat_sum_py.session == session
            d = dat_sum_py.loc[is_session & is_term & is_wave, :].copy()
            overlay_lh = get_overlay(d.idx, d.t_stat, "left", nparc)
            overlay_rh = get_overlay(d.idx, d.t_stat, "right", nparc)
            overlay = np.column_stack(np.stack((overlay_lh, overlay_rh)))
            fig = plot_surf_roi_montage(
                roi_map=overlay,
                title=session + " " + term + " " + wave,
                vmax=vmax,
                vmin=vmin,
                cbar_vmax=vmax,
                cbar_vmin=vmin,
                bg_on_data=True,
            )
            plt.show()
            fig.clear()
            plt.close("all")  # release each figure before drawing the next
# Surface maps of the group t statistic `t_stat`, with parcels that do not
# reach p_fdr < 0.05 set to 0. The colour range is fixed per term over all
# sessions, waves and parcels (including the sub-threshold ones).
# The threshold matches the R summaries (p_fdr < 0.05). The previous
# `p_fdr > 0.05` test kept p_fdr == 0.05 and NaN.
# NOTE(review): `models` and `waves` are not defined in the visible R code;
# presumably they come from src/stroop-rsa-pc.R -- confirm.
dat_sum_py = r.dat_sum  # convert the R tibble to pandas once, not per access
crcor_terms = r.models["crcor"]
wave_labels = r.waves
session_labels = r.sessions
for term in crcor_terms:
    is_term = dat_sum_py.term == term
    # set color range per model (term); invariant across waves/sessions
    vmax = np.max(dat_sum_py.t_stat[is_term])
    vmin = np.min(dat_sum_py.t_stat[is_term])
    for wave in wave_labels:
        is_wave = dat_sum_py.wave == wave
        for session in session_labels:
            is_session = dat_sum_py.session == session
            d = dat_sum_py.loc[is_session & is_term & is_wave, :].copy()
            d.loc[~(d.p_fdr < 0.05), "t_stat"] = 0
            overlay_lh = get_overlay(d.idx, d.t_stat, "left", nparc)
            overlay_rh = get_overlay(d.idx, d.t_stat, "right", nparc)
            overlay = np.column_stack(np.stack((overlay_lh, overlay_rh)))
            fig = plot_surf_roi_montage(
                roi_map=overlay,
                title=session + " " + term + " " + wave,
                vmax=vmax,
                vmin=vmin,
                cbar_vmax=vmax,
                cbar_vmin=vmin,
                bg_on_data=True,
            )
            plt.show()
            fig.clear()
            plt.close("all")  # release each figure before drawing the next